/*    Copyright 2010 Tobias Marschall
 *
 *    This file is part of MoSDi.
 *
 *    MoSDi is free software: you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation, either version 3 of the License, or
 *    (at your option) any later version.
 *
 *    MoSDi is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with MoSDi.  If not, see <http://www.gnu.org/licenses/>.
 */

package mosdi.subcommands;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.StringTokenizer;

import mosdi.distributions.PoissonDistribution;
import mosdi.fa.Alphabet;
import mosdi.fa.CDFA;
import mosdi.fa.DFAFactory;
import mosdi.fa.FiniteMemoryTextModel;
import mosdi.fa.GeneralizedString;
import mosdi.fa.GeneralizedStringsHammingNFA;
import mosdi.fa.GeneralizedStringsNFA;
import mosdi.fa.IIDTextModel;
import mosdi.fa.MarkovianTextModel;
import mosdi.fa.NFA;
import mosdi.fa.NodeLimitExceededException;
import mosdi.fa.PrositeMotifParser;
import mosdi.paa.ClumpSizeCalculator;
import mosdi.paa.MatchCountDAA;
import mosdi.paa.PAA;
import mosdi.paa.TextBasedPAA;
import mosdi.util.Distributions;
import mosdi.util.FileUtils;
import mosdi.util.Iupac;
import mosdi.util.Log;
import mosdi.util.NamedSequence;
import mosdi.util.SequenceUtils;
import mosdi.util.iterators.DFALanguageIterator;
import mosdi.util.iterators.LexicographicalIterator;

/**
 * Subcommand {@code count-dist}: calculates the exact or approximate (compound Poisson or
 * Poisson) distribution of a pattern's occurrence count in a random text, and optionally the
 * p-value of the number of matches found in the sequences of a FASTA file.
 */
public class StatSubcommand extends Subcommand {

	private enum CountingScheme { OVERLAPPING, NON_OVERLAPPING, MATCH_POSITIONS }
	private enum PatternFormat { IUPAC, PROSITE }
	private enum Mode { EXACT, COMP_POISSON, POISSON }

	@Override
	public String usage() {
		return
		super.usage()+" [options] <mode> <pattern>\n" +
		"\n" +
		"  <mode>   \"exact\"        (exact statistics, slowest),\n" +
		"           \"comp-poisson\" (compound Poisson approximation, faster\n" +
		"                           and still very accurate for rare patterns),\n" +
		"           \"poisson\"      (Poisson approximation, fastest but less\n" +
		"                           accurate for self-overlapping patterns)\n" +
		"  <pattern> pattern description (set format with -f, default is IUPAC).\n" +
		"            When a comma-separated list of multiple patterns is given,\n" +
		"            an instance of any of those is considered a match.\n" +
		"            If -F option is given, <pattern> is interpreted as name\n" +
		"            of file containing multiple patterns (one per line)\n" +
		"\n" +
		"Options controlling the pattern:\n" +
		"  -F: read patterns from file (one pattern per line)\n" +
		"  -f <pattern-format>: iupac (default), prosite (sets used alphabet accordingly)\n" +
		"  -r: simultaneously consider reverse complementary motif (only for iupac)\n" +
		"  -c <O|N|M>: set counting scheme: (O)verlapping (default),\n" +
		"              (N)on-Overlapping, (M)atch positions\n" +
		"  -H <hamming-distance>: add Hamming neighborhood to patterns\n" +
		"\n" +
		"Options for mode \"exact\"\n" +
		"  -D: use doubling algorithm\n" +
		"  -C: auto-choose algorithm (standard/doubling) according to runtime estimates\n" +
		"\n" +
		"Options for mode \"comp-poisson\"\n" +
		"  -s <max-clump-size>: size of clump size distribution (default:30)\n" +
		"                       Accuracy decreases if chosen too small for strongly\n" +
		"                       overlapping patterns.\n" +
		"\n" +
		"Other options:\n" +
		"  -n <number-of-steps>: length of random text (default: 100)\n" +
		"  -m <maxmatches>: maximal number of matches (default: 10)\n" +
		"  -p <fasta-file>: read sequence from file and calculate p-value\n" +
		"                   Sets options -n and -m accordingly. If -q is not\n" +
		"                   given, the text model is estimated from given sequences.\n" +
		"  -V <order>: use Markovian text model of given order (to be used with option -p)\n" +
		"  -q <q-gram-table-file>: Estimate text model from given q-gram table (default: uniform i.i.d.)\n" +
		"                          A q-gram table can be created by using \"mosdi-utils count-qgrams\".\n" +
		"  -N <nodelimit>: if exceeded, construction of automaton will be aborted (default: 1000)\n" +
		"  -g: compute the number of different recognized strings";
	}

	@Override
	public String description() {
		return "Calculates (exact or approximate) distribution of a pattern's occurrence count.";
	}

	@Override
	public String name() {
		return "count-dist";
	}

	/**
	 * Returns the number of start positions for a pattern of length {@code patternLength}
	 * in a text of length {@code textLength}. The result is clamped at zero so that texts
	 * shorter than the pattern contribute nothing. Without the clamp, the naive formula
	 * {@code textLength-patternLength+1} would turn negative.
	 */
	private static long windowCount(long textLength, int patternLength) {
		return Math.max(0L, textLength - patternLength + 1);
	}

	/**
	 * Scales a single-position expectation to the whole text(s).
	 *
	 * @param singleExpectation value returned by {@code SequenceUtils.computeExpectation}
	 *                          for the pattern length
	 * @param patternLength     length of the (equal-length) generalized strings
	 * @param steps             length of the random text; used only if {@code sequences} is null
	 * @param sequences         sequences read from the FASTA file, or null
	 * @return singleExpectation times the total number of windows
	 */
	private static double expectedMatchCount(double singleExpectation, int patternLength, long steps, List<int[]> sequences) {
		if (sequences == null) {
			return singleExpectation * windowCount(steps, patternLength);
		}
		double expectation = 0.0;
		for (int[] s : sequences) {
			expectation += singleExpectation * windowCount(s.length, patternLength);
		}
		return expectation;
	}

	/**
	 * Parses the command line, builds an automaton for every pattern, computes the requested
	 * occurrence count distribution and prints one result line per pattern.
	 * Fatal input errors are reported via {@code Log.errorln} and end the process with
	 * {@code System.exit(1)}.
	 *
	 * @return 0 after all patterns have been processed
	 * @throws IllegalStateException if a non-exact mode or -g is used with generalized
	 *         strings of different lengths
	 */
	@Override
	public int run(String[] args) {
		parseOptions(args, 2, "f:Fn:m:c:p:N:rDCV:q:H:s:g");

		// Option dependencies
		exclusiveOptions("C", "D");
		exclusiveOptions("p", "n");
		exclusiveOptions("p", "m");
		exclusiveOptions("V", "q");
		impliedOptions("V", "p");

		// Mandatory arguments
		Mode mode = getEnumArgument(0, Mode.values(), "exact", "comp-poisson", "poisson");
		String patternArgument = getStringArgument(1);

		// Options
		long steps = getNonNegativeLongOption("n", 100);
		int maxMatches = getNonNegativeIntOption("m", 10);
		CountingScheme countingScheme = getEnumOption("c", CountingScheme.OVERLAPPING, CountingScheme.values(), "O", "N", "M");
		boolean doubling = getBooleanOption("D", false);
		PatternFormat patternFormat = getEnumOption("f", PatternFormat.IUPAC, PatternFormat.values(), "iupac", "prosite");
		int nodeLimit = getPositiveIntOption("N", 1000);
		boolean considerReverse = getBooleanOption("r", false);
		boolean autoChooseAlgorithm = getBooleanOption("C", false);
		int textModelOrder = getNonNegativeIntOption("V", 0);
		int hammingDistance = getNonNegativeIntOption("H", 0);
		int maxClumpSize = getPositiveIntOption("s", 30);
		boolean readPatternsFromFile = getBooleanOption("F", false);
		String fastaFilename = getStringOption("p", null);
		String qGramTableFilename = getStringOption("q", null);
		boolean computeStringCount = getBooleanOption("g", false);

		// create list of patterns
		List<String> patternList;
		if (readPatternsFromFile) {
			// read patterns from file
			patternList = FileUtils.readPatternFile(patternArgument);
		} else {
			patternList = new ArrayList<String>(1);
			patternList.add(patternArgument);
		}

		Alphabet alphabet = null;
		switch (patternFormat) {
		case IUPAC: alphabet = Alphabet.getDnaAlphabet(); break;
		case PROSITE: alphabet = Alphabet.getAminoAcidAlphabet(); break;
		}

		FiniteMemoryTextModel textModel = null;
		List<int[]> sequences = null;
		if (fastaFilename!=null) {
			try {
				List<NamedSequence> namedSequences = SequenceUtils.readFastaFile(fastaFilename, alphabet);
				sequences = SequenceUtils.sequenceList(namedSequences);
			} catch (Exception e) {
				Log.errorln(e.toString());
				System.exit(1);
			}
			// without a q-gram table, estimate the text model from the given sequences
			if (qGramTableFilename==null) {
				if (textModelOrder==0) {
					textModel = new IIDTextModel(alphabet.size(), sequences);
				} else {
					textModel = new MarkovianTextModel(textModelOrder, alphabet.size(), sequences);
				}
			}
			if (sequences.size()==1) {
				steps = sequences.get(0).length;
			} else {
				if (doubling) {
					Log.errorln("Error: Option -D: Doubling algorithm cannot be used for multiple sequences ("+sequences.size()+" found in fasta file).");
					System.exit(1);
				}
			}
		}

		if (considerReverse) {
			if (patternFormat==PatternFormat.PROSITE) {
				Log.errorln("Error: Option -r not usable for PROSITE patterns.");
				System.exit(1);
			}
		}
		if (qGramTableFilename!=null) {
			try {
				textModel = SequenceUtils.buildTextModelFromQGramFile(qGramTableFilename);
			} catch (Exception e) {
				Log.errorln(e.toString());
				System.exit(1);
			}
		}

		// default: uniform i.i.d. text model
		if (textModel==null) {
			textModel = new IIDTextModel(alphabet.size());
		}

		Log.printf(Log.Level.VERBOSE, "Text model states: \"%d\"%n", textModel.getStateCount());
		if (Log.levelAtLeast(Log.Level.DEBUG)) {
			Log.print(Log.Level.DEBUG, textModel.toString());
		}

		// sorted sequence lengths let the multi-sequence exact computation advance one
		// state-value distribution incrementally instead of restarting per sequence
		int[] sequenceLengths = null;
		if ((sequences!=null) && (mode==Mode.EXACT)) {
			sequenceLengths = new int[sequences.size()];
			int n = 0;
			for (int[] s : sequences) sequenceLengths[n++] = s.length;
			Arrays.sort(sequenceLengths);
		}

		Log.println(Log.Level.STANDARD, "Format: >>p_value>> p-value >>stats>> pattern config #generalized-strings #recognized-strings #dfa-states #dfa-states-minimized #paa-states #textModel-states #matches constant-factor expectation expected-clump-size >>runtimes>> dfa-construction minimization matching statistics total >>dist>> occurrence-count-distribution ");
		// MAIN LOOP: iterate over all patterns
		for (String pattern : patternList) {
			boolean skip = false;
			double timeCdfaConstruction = -1.0;
			double timeMinimization = -1.0;
			double timeMatching = -1.0;
			double timeStatistics = -1.0;
			double timeTotal = -1.0;
			int states = -1;
			int statesMinimal = -1;
			int paaStates = -1;
			int matches = -1;
			double expectation = -1.0;
			double expectedClumpSize = -1.0;
			// the factor hidden behind the O() notation
			double constantFactor = -1.0;
			int recognizedStrings = -1;
			Log.startTimer();
			// construct automaton
			CDFA cdfa = null;
			List<GeneralizedString> genStringList = new ArrayList<GeneralizedString>();
			StringTokenizer st = new StringTokenizer(pattern, ",", false);
			while (st.hasMoreTokens()) {
				if (patternFormat==PatternFormat.IUPAC) {
					genStringList.addAll(Iupac.toGeneralizedStrings(st.nextToken(), considerReverse));
				}
				if (patternFormat==PatternFormat.PROSITE) {
					genStringList.addAll(PrositeMotifParser.parse(st.nextToken()));
				}
			}
			// the approximations and the string counting assume one common pattern length
			if ((mode!=Mode.EXACT) || computeStringCount) {
				int l = genStringList.get(0).length();
				for (GeneralizedString gs : genStringList) {
					if (gs.length()!=l) throw new IllegalStateException("Not implemented for generalized strings for different lengths.");
				}
			}
			try {
				NFA nfa;
				if (hammingDistance==0) {
					nfa = new GeneralizedStringsNFA(genStringList);
				} else {
					nfa = new GeneralizedStringsHammingNFA(genStringList, hammingDistance);
				}
				cdfa = DFAFactory.build(alphabet, nfa, nodeLimit);
				timeCdfaConstruction = Log.getLastPeriodCpu();
				states = cdfa.getStateCount();
				Log.printf(Log.Level.VERBOSE, "DFA states: \"%d\"%n", states);
				if (countingScheme==CountingScheme.NON_OVERLAPPING) cdfa = cdfa.toNonOverlapping();
				if (countingScheme==CountingScheme.MATCH_POSITIONS) cdfa = cdfa.toMatchPositionCount();
				cdfa = cdfa.minimizeHopcroft();
				timeMinimization = Log.getLastPeriodCpu();
				statesMinimal = cdfa.getStateCount();
				Log.printf(Log.Level.VERBOSE, "DFA states (after minimization): \"%d\"%n", statesMinimal);
			} catch (NodeLimitExceededException e) {
				Log.println(Log.Level.STANDARD, "Skipping: node limit exceeded (switch -N)");
				skip = true;
			}
			if (computeStringCount && !skip) {
				int l = genStringList.get(0).length();
				LexicographicalIterator iterator = new DFALanguageIterator(cdfa,l,alphabet.size());
				int i = 0;
				while (iterator.hasNext()) {
					iterator.next();
					i += 1;
				}
				recognizedStrings = i;
				Log.printf(Log.Level.STANDARD, "CDFA recognizes %d strings of length %d\n", i, l);
			}

			// cdfa may be null (or incomplete) when construction was skipped
			if (!skip && Log.levelAtLeast(Log.Level.DEBUG)) {
				Log.print(Log.Level.DEBUG, cdfa.toString());
			}

			// if calculating p-value, then perform matching and set max_matches accordingly;
			// this needs an automaton, so it is not done for skipped patterns
			if ((sequences!=null) && !skip) {
				Log.startTimer();
				matches = 0;
				for (int[] s : sequences) {
					matches += cdfa.countMatches(s);
				}
				Log.stopTimer("DFA matching");
				timeMatching = Log.getLastPeriodCpu();
				Log.printf(Log.Level.STANDARD, "Found %d matches in %d sequence%s%n", matches, sequences.size(), (sequences.size()>1?"s":""));
				maxMatches = matches;
			}

			// resulting distribution of matches we are interested in
			double[] matchDistribution = null;
			if (!skip) {
				if (mode==Mode.EXACT) {
					MatchCountDAA daa = new MatchCountDAA(cdfa, maxMatches);
					PAA paa = new TextBasedPAA(daa, textModel);
					paaStates = paa.getStateCount();
					Log.printf(Log.Level.VERBOSE, "PAA has %d states.%n", paaStates);
					if (Log.levelAtLeast(Log.Level.DEBUG)) {
						Log.print(Log.Level.DEBUG, paa.toString());
					}
					if (autoChooseAlgorithm) {
						double estimateStandard = 1e-8*((double)steps)*paa.getStateCount()*alphabet.size()*(maxMatches+1);
						double estimateDoubling = 2.5e-8*Math.log(steps)*Math.pow(paa.getStateCount(), 3.0)*Math.pow(maxMatches+1,2.0);
						doubling = estimateDoubling<estimateStandard;
						Log.printf(Log.Level.STANDARD, "Algorithm-chooser: estimate standard/doubling: %e/%e --> choosing %s%n",
								estimateStandard, estimateDoubling, doubling?"DOUBLING":"STANDARD");
					}

					double[][] stateValueDistribution;
					if ((sequences!=null) && (sequences.size()>1)){
						Log.startTimer();
						stateValueDistribution = paa.stateValueStartDistribution();
						// value distribution for the current prefix length n (initially 0), so that
						// zero-length sequences contribute the start distribution instead of null
						double[] valueDistribution = paa.toValueDistribution(stateValueDistribution);
						int n = 0;
						// lengths are sorted ascending: advance the distribution only up to the next length
						for (int length : sequenceLengths) {
							if (length>n) {
								for (; n<length; ++n) {
									stateValueDistribution = paa.updateStateValueDistribution(stateValueDistribution);
								}
								valueDistribution = paa.toValueDistribution(stateValueDistribution);
							}
							if (matchDistribution==null) {
								matchDistribution = valueDistribution;
							} else {
								matchDistribution = Distributions.convolveLengthPreserving(matchDistribution, valueDistribution, true);
							}
						}
						Log.stopTimer("Computing PAA state-value distribution for multiple sequences (LINEAR)");
						timeStatistics = Log.getLastPeriodCpu();
					} else {
						if (doubling) {
							Log.startTimer();
							stateValueDistribution = paa.stateValueDistributionViaDoubling(steps);
							Log.stopTimer("Computing PAA state-value distribution (DOUBLING)");
							constantFactor=1.0/(Math.log(steps)*Math.pow(statesMinimal, 3.0)*Math.pow(maxMatches+1,2.0));
						} else {
							Log.startTimer();
							stateValueDistribution = paa.computeStateValueDistribution(steps);
							Log.stopTimer("Computing PAA state-value distribution (LINEAR)");
							constantFactor=1.0/(((double)steps)*statesMinimal*alphabet.size()*(maxMatches+1));
						}
						matchDistribution = paa.toValueDistribution(stateValueDistribution);
						timeStatistics = Log.getLastPeriodCpu();
						constantFactor*=timeStatistics;
					}
				}
				if (mode==Mode.COMP_POISSON) {
					// Two timers are started here and stopped by the two stopTimer calls below:
					// the first spans the whole compound Poisson computation, the second is
					// restarted per sub-step (assumes Log keeps a stack of timers).
					Log.startTimer();
					Log.startTimer();
					int length = genStringList.get(0).length();
					double singleExpectation = SequenceUtils.computeExpectation(cdfa, textModel, length);
					expectation = expectedMatchCount(singleExpectation, length, steps, sequences);
					Log.restartTimer("Computing expectation");
					ClumpSizeCalculator csc = new ClumpSizeCalculator(textModel, cdfa, length);
					paaStates = csc.getProductStateCount();
					double[] clumpSizeDist = csc.clumpSizeDistribution(maxClumpSize, 1e-300);
					Log.restartTimer("Calculating clump size distribution");
					// calculate expected clump size
					expectedClumpSize = 0.0;
					for (int i=1; i<clumpSizeDist.length; ++i) {
						expectedClumpSize+=clumpSizeDist[i]*i;
					}
					Log.printf(Log.Level.VERBOSE, "Clump size distribution: %s\n", Arrays.toString(clumpSizeDist));
					Log.printf(Log.Level.VERBOSE, "Expected clump size: %e\n", expectedClumpSize);
					// expected number of clumps = expected matches / expected matches per clump
					double lambda = expectation/expectedClumpSize;
					PoissonDistribution poissonDist = new PoissonDistribution(lambda);
					matchDistribution = poissonDist.compoundPoissonDistribution(clumpSizeDist, maxMatches);
					Log.stopTimer("Distribution/p-value calculation by convolution");
					Log.stopTimer("Distribution by compound Poisson");
					timeStatistics = Log.getLastPeriodCpu();
				}
				if (mode==Mode.POISSON) {
					Log.startTimer();
					int length = genStringList.get(0).length();
					double singleExpectation = SequenceUtils.computeExpectation(cdfa, textModel, length);
					expectation = expectedMatchCount(singleExpectation, length, steps, sequences);
					PoissonDistribution poissonDist = new PoissonDistribution(expectation);
					matchDistribution = poissonDist.get(maxMatches+1, true);
					Log.stopTimer("Computing Poisson distribution");
					timeStatistics = Log.getLastPeriodCpu();
				}
			}
			Log.stopTimer("total");
			timeTotal = Log.getLastPeriodCpu();

			StringBuilder sb = new StringBuilder();
			if (skip) {
				sb.append(">>p_value>> -1.0 ");
			} else {
				// the last entry of the distribution is reported as the p-value
				sb.append(String.format(">>p_value>> %e ", matchDistribution[maxMatches]));
			}
			sb.append(String.format(">>stats>> %s %s %d %d %d %d %d %d %d %e %e %e ", pattern, countingScheme.toString().charAt(0), genStringList.size(), recognizedStrings, states, statesMinimal, paaStates, textModel.getStateCount(), matches, constantFactor, expectation, expectedClumpSize));
			sb.append(Log.format(">>runtimes>> %t %t %t %t %t ", timeCdfaConstruction, timeMinimization, timeMatching, timeStatistics, timeTotal));
			if (!skip) {
				sb.append(">>dist>> ");
				for (int i=0; i<=maxMatches; ++i) {
					sb.append(String.format("%e ", matchDistribution[i]));
				}
			}
			if (sb.length()>0) Log.println(Log.Level.STANDARD, sb.toString());
			Log.println(Log.Level.STANDARD, "=============================");
		}
		return 0;
	}
}


